import torch
import numpy as np
from itertools import repeat
import higra as hg

def get_ultrametric(feat, rays, data_pairs, seg_masks=None, visualize=False, pix_ids=None, img_wh=None, k=10):
    """
    Return the MST-ultrametric feature similarity for pairs of points.

    A k-nearest-neighbour graph is built over ``rays`` (squared Euclidean
    distance). Each edge gets the length ``1 - <feat[src], feat[tgt]>``.
    When ``seg_masks`` is given, edges whose endpoints differ in any mask
    get an extra +10. The binary partition tree (BPT) of that graph is
    computed with higra. For each requested pair, the pair's lowest common
    ancestor in the tree identifies the MST edge that joins the two points,
    and the feature dot product of that edge is returned.

    Input:
        feat: (N, C) tensor of point features (presumably unit-normalised,
            so 1 - dot behaves like a cosine distance -- TODO confirm).
        rays: per-point coordinates used only to build the kNN graph; one
            row per point, assumed to be (N, D) -- TODO confirm.
        data_pairs: (Num_masks, Num_pairs, 2) indices of point pairs.
        seg_masks: optional tensor indexed as ``seg_masks[:, point]``;
            edges whose endpoints disagree in any row are lengthened by 10.
        visualize, pix_ids, img_wh: unused; kept for API compatibility.
        k: neighbours per point, excluding the point itself. Clamped to
            (number of rays - 1).
    Output:
        (Num_masks, Num_pairs) tensor of similarities. Pairs whose two
        indices are equal get 1.
    """
    N, _ = feat.shape
    Num_masks, Num_pairs, D = data_pairs.shape
    assert D == 2

    # kNN graph over the ray coordinates. The closest hit is assumed to be
    # the point itself (distance 0), so it is dropped.
    pixs_distances = ((rays.unsqueeze(0) - rays.unsqueeze(1)) ** 2).sum(-1)
    num_rays = pixs_distances.shape[0]
    k = min(k, num_rays - 1)
    pixs_topk = torch.topk(pixs_distances, k + 1, largest=False)[1][:, 1:]

    # Edge e = i * k + j joins point i to its j-th neighbour.
    srcs = np.repeat(np.arange(num_rays), k)
    tgts = pixs_topk.reshape(-1)

    data_pairs = data_pairs.reshape(-1, 2)

    # Similarity per edge; the BPT works on the complementary "length".
    data = (feat[srcs] * feat[tgts]).sum(-1)
    graph_edge_lengths = 1 - data

    if seg_masks is not None:
        # Lengthen edges that cross any mask boundary so the BPT tends to
        # merge across boundaries late.
        on_boundary = torch.any(seg_masks[:, srcs] != seg_masks[:, tgts], dim=0)
        graph_edge_lengths[on_boundary] += 10

    tree, _ = hg.bpt_canonical(
        (srcs, tgts.cpu().numpy(), len(feat)),
        graph_edge_lengths.detach().cpu().numpy().astype(float),
    )

    # One vectorised LCA query for all pairs instead of a Python loop.
    tree.lowest_common_ancestor_preprocess()
    if torch.is_tensor(data_pairs):
        pairs_np = data_pairs.detach().cpu().numpy()
    else:
        pairs_np = np.asarray(data_pairs)
    pairs_np = pairs_np.astype(np.int64, copy=False)
    lca = tree.lowest_common_ancestor(
        np.ascontiguousarray(pairs_np[:, 0]),
        np.ascontiguousarray(pairs_np[:, 1]),
    )

    # An LCA that is a leaf (id < N) only occurs when v1 == v2. Clamp it to
    # the first internal node so indexing stays valid; that value is
    # overwritten with 1 below.
    lca = np.maximum(np.asarray(lca, dtype=np.int64), N)
    # Internal node N + i corresponds to MST edge i; mst_edge_map takes it
    # back to the edge index in the kNN graph (i.e. into `data`).
    edge_idx = tree.mst_edge_map[lca - N]

    ultrametric = data[edge_idx]
    ultrametric[data_pairs[:, 0] == data_pairs[:, 1]] = 1
    return ultrametric.reshape(Num_masks, Num_pairs)

def get_tree(feat, k_index, batch_size=1_000_000):
    """
    Return the hierarchical tree for the feature field.

    Each point is connected to each of its neighbours in ``k_index``. Edges
    are weighted by the Euclidean distance between the endpoint features,
    with a 1e-8 epsilon under the square root. The binary partition tree of
    that graph is built with higra and then passed through
    ``hg.canonize_hierarchy``.

    Input:
        feat: (num_points, C) tensor of point features.
        k_index: (num_points, k) integer indices; ``k_index[i, j]`` is the
            j-th neighbour of point i.
        batch_size: number of edges whose lengths are computed at once.
            This bounds peak memory for large graphs.
    Output:
        (tree, altitudes) as returned by ``hg.canonize_hierarchy``.
    """
    k_index = torch.as_tensor(k_index)
    num_points, k = k_index.shape
    # Edge e = j * num_points + i joins point i to its j-th neighbour.
    rows = torch.arange(num_points).repeat(k)
    cols = k_index.transpose(0, 1).reshape(-1).to(torch.int64).cpu()

    # Edge lengths are only consumed as a NumPy array by higra, so no
    # autograd graph is needed (and a grad-tracking tensor could not be
    # written into a NumPy buffer anyway).
    data = np.zeros(len(rows))
    with torch.no_grad():
        for start in range(0, len(rows), batch_size):
            stop = start + batch_size
            # Gather only this batch's endpoints. Indexing feat[rows] as a
            # whole would materialise every edge and defeat the batching.
            src = feat[rows[start:stop].to(feat.device)]
            dst = feat[cols[start:stop].to(feat.device)]
            lengths = (((src - dst) ** 2).sum(-1) + 1e-8).sqrt()
            data[start:stop] = lengths.cpu().numpy()

    g = hg.UndirectedGraph(num_points)
    g.add_edges(rows.numpy(), cols.numpy())
    tree, altitudes = hg.bpt_canonical(g, data)
    tree, altitudes = hg.canonize_hierarchy(tree, altitudes)

    # TODO: filter the tree
    return tree, altitudes
